#!/usr/bin/env python
# encoding: utf-8

import torchaudio
import torch
import torch.nn as nn
from torch.nn import functional as F

from scipy import signal
import numpy as np

class PreEmphasis(torch.nn.Module):
    """First-order pre-emphasis filter: ``y[t] = x[t] - coef * x[t-1]``.

    The first output sample uses reflect padding (``x[-1]`` is taken to be
    ``x[1]``), so the output has the same length as the input.

    Args:
        coef (float): pre-emphasis coefficient.
    """

    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # F.conv1d computes cross-correlation, so the FIR taps [1, -coef] are
        # stored reversed. Shape is (out_channels=1, in_channels=1, width=2).
        self.register_buffer(
            'flipped_filter',
            torch.tensor([-self.coef, 1.0], dtype=torch.float32).view(1, 1, 2),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Apply pre-emphasis along the last axis.

        Args:
            inputs (torch.Tensor): 2-D tensor (batch, time). The reflect pad
                needs at least 2 samples along time.

        Returns:
            torch.Tensor: tensor of shape (batch, time).
        """
        assert inputs.dim() == 2, 'The number of dimensions of inputs tensor must be 2!'
        # (batch, time) -> (batch, 1, time). Reflect-pad one sample on the left
        # so that the valid convolution keeps the input length.
        padded = F.pad(inputs.unsqueeze(1), (1, 0), 'reflect')
        return F.conv1d(padded, self.flipped_filter).squeeze(1)


def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian window.

    Element ``i`` (for ``i`` in ``range(window_size)``) is proportional to
    ``exp(-(i - window_size // 2)**2 / (2 * sigma**2))``. The window is then
    divided by its sum, so it sums to 1.

    Args:
        window_size (int): number of taps; must be positive.
        sigma (float): standard deviation in samples; must be non-zero.

    Returns:
        torch.Tensor: shape ``(window_size,)`` in torch's default float dtype.

    Raises:
        ValueError: if ``window_size <= 0`` or ``sigma == 0``.
    """
    if window_size <= 0:
        raise ValueError(f"window_size must be positive. Got {window_size}")
    if sigma == 0:
        raise ValueError("sigma must be non-zero")
    center = window_size // 2
    # Evaluate the exponent in float64 (the same as Python-float arithmetic) and
    # only then round to the default dtype, so the values match a per-element
    # construction without a Python loop.
    offsets = torch.arange(window_size, dtype=torch.float64) - center
    exponent = -offsets ** 2 / float(2 * sigma ** 2)
    gauss = torch.exp(exponent.to(torch.get_default_dtype()))
    return gauss / gauss.sum()


def get_gaussian_kernel(ksize: int, sigma: float) -> torch.Tensor:
    r"""Build the 1-D Gaussian filter coefficients for a given window length.

    Args:
        ksize (int): number of taps; must be an odd, positive ``int``.
        sigma (float): standard deviation of the Gaussian, in samples.

    Returns:
        torch.Tensor: normalized coefficients of shape ``(ksize,)``, as produced
        by ``gaussian(ksize, sigma)``.

    Raises:
        TypeError: if ``ksize`` is not an odd, positive ``int``. This is raised
            as TypeError even for value problems (even or non-positive).
    """
    is_valid_size = isinstance(ksize, int) and ksize > 0 and ksize % 2 == 1
    if not is_valid_size:
        raise TypeError(f"ksize must be an odd positive integer. Got {ksize}")
    return gaussian(ksize, sigma)


class Guassian_Filter(nn.Module):
    """Fixed (non-learnable) 1-D Gaussian smoothing via ``F.conv1d``.

    Args:
        ksize (int): kernel length; validated by ``get_gaussian_kernel`` (must
            be an odd, positive int).
        sigma (float): Gaussian standard deviation, in samples.
    """

    def __init__(self, ksize, sigma):
        super(Guassian_Filter, self).__init__()
        # Shape (out_channels=1, in_channels=1, ksize). Previously hard-coded
        # to (5, 1), which silently ignored the constructor arguments.
        kernel = get_gaussian_kernel(ksize, sigma).unsqueeze(0).unsqueeze(0)
        # A non-persistent buffer follows module.to()/.cuda() but is not added
        # to state_dict, so checkpoints saved before this change still load.
        self.register_buffer('kernel', kernel, persistent=False)

    def forward(self, x):
        """Convolve ``x`` with the Gaussian kernel (no padding).

        Args:
            x (torch.Tensor): input of shape (batch, 1, length). The kernel has
                a single input channel.

        Returns:
            torch.Tensor: output of shape (batch, 1, length - ksize + 1).
        """
        # Match the input's device (and float dtype), so the filter works
        # whether or not the module itself was moved.
        dtype = x.dtype if x.is_floating_point() else self.kernel.dtype
        kernel = self.kernel.to(device=x.device, dtype=dtype)
        return F.conv1d(x, kernel)


def gaussian_filter(sig, kernel_size=5, std=1):
    """Smooth a signal with a Gaussian window, keeping its length (odd kernels).

    Args:
        sig (torch.Tensor): signal; flattened to 1-D before filtering.
        kernel_size (int): window length. Odd values keep the output length
            equal to the input length. Even values yield one extra sample.
        std (float): standard deviation in samples, passed to
            ``scipy.signal.windows.gaussian``.

    Returns:
        torch.Tensor: filtered signal of shape (1, N), on ``sig``'s device.
    """
    # NOTE(review): the window (peak value 1) is divided by kernel_size rather
    # than by its own sum, so the filter does not have unit DC gain (about
    # 0.497 for the defaults). This is kept for compatibility; confirm it is
    # intended.
    window = signal.windows.gaussian(kernel_size, std) / float(kernel_size)
    dtype = sig.dtype if sig.is_floating_point() else torch.float32
    kernel = torch.as_tensor(window, dtype=dtype, device=sig.device).view(1, 1, -1)
    sig = sig.reshape(1, 1, -1)
    # Zero-pad half the kernel on each side. This was previously hard-coded
    # to 2, which only preserved the length for kernel_size=5.
    sig = F.conv1d(sig, kernel, padding=kernel_size // 2)
    return sig.reshape(1, -1)

def mean_filter(sig, kernel_size=5):
    """Apply a moving-average (box) filter with no padding.

    Args:
        sig (torch.Tensor): signal; flattened to 1-D before filtering. It must
            have at least ``kernel_size`` samples.
        kernel_size (int): averaging window length.

    Returns:
        torch.Tensor: filtered signal of shape (1, N - kernel_size + 1), on
        ``sig``'s device.
    """
    flat = sig.reshape(1, 1, -1)
    # Build the kernel on the input's device (it was previously forced to
    # CUDA). Use the input's float dtype when it has one.
    dtype = sig.dtype if sig.is_floating_point() else torch.float32
    kernel = torch.full(
        (1, 1, kernel_size), 1.0 / kernel_size, dtype=dtype, device=sig.device
    )
    return F.conv1d(flat, kernel).reshape(1, -1)

def median_filter(sig, kernel_size=5):
    """Apply a 1-D median filter using ``scipy.signal.medfilt``.

    The signal is copied to host memory and flattened. ``medfilt`` zero-pads
    the edges, so the output has the same length as the input.

    Args:
        sig (torch.Tensor): signal of any shape; it is detached from autograd.
        kernel_size (int): odd window length.

    Returns:
        torch.Tensor: float32 tensor of shape (1, N), on ``sig``'s device
        (previously forced to CUDA).
    """
    assert kernel_size % 2 == 1
    values = sig.detach().cpu().numpy().reshape(-1)
    smoothed = signal.medfilt(values, kernel_size)
    return torch.tensor(smoothed, dtype=torch.float32, device=sig.device).reshape(1, -1)
